{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "matchzoo version 2.1.0\n",
      "\n",
      "data loading ...\n",
      "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n",
      "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n",
      "loading embedding ...\n",
      "embedding loaded as `glove_embedding`\n"
     ]
    }
   ],
   "source": [
    "# Run the shared setup notebook. Per its recorded output it loads\n",
    "# `train_pack_raw`, `dev_pack_raw` and `test_pack_raw`, initializes a default\n",
    "# `ranking_task`, and loads `glove_embedding`.\n",
    "# NOTE(review): `mz` and `append_params_to_readme` are used in later cells but\n",
    "# never defined in this notebook -- presumably they come from init.ipynb; confirm.\n",
    "%run init.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter: 100%|██████████| 2118/2118 [00:00<00:00, 3587.72it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter: 100%|██████████| 18841/18841 [00:04<00:00, 4528.13it/s]\n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 592156.77it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 432217.30it/s]\n",
      "Building Vocabulary from a datapack.: 100%|██████████| 1614998/1614998 [00:00<00:00, 4239505.32it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 2118/2118 [00:00<00:00, 2709.71it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 18841/18841 [00:11<00:00, 1656.57it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 122/122 [00:00<00:00, 1120.91it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 1115/1115 [00:00<00:00, 1895.34it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 237/237 [00:00<00:00, 1910.44it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 2300/2300 [00:01<00:00, 1630.79it/s]\n"
     ]
    }
   ],
   "source": [
    "preprocessor = mz.preprocessors.DSSMPreprocessor()\n",
    "train_pack_processed = preprocessor.fit_transform(train_pack_raw)\n",
    "valid_pack_processed = preprocessor.transform(dev_pack_raw)\n",
    "test_pack_processed = preprocessor.transform(test_pack_raw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'vocab_unit': <matchzoo.preprocessors.units.vocabulary.Vocabulary at 0x7f0d7fe29fd0>,\n",
       " 'vocab_size': 9645,\n",
       " 'embedding_input_dim': 9645,\n",
       " 'input_shapes': [(9645,), (9645,)]}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the context stored on the fitted preprocessor; its `input_shapes`\n",
    "# entry is what the model cell reads when configuring the DSSM inputs.\n",
    "preprocessor.context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "ranking_task = mz.tasks.Ranking(loss=mz.losses.RankCrossEntropyLoss(num_neg=4))\n",
    "ranking_task.metrics = [\n",
    "    mz.metrics.NormalizedDiscountedCumulativeGain(k=3),\n",
    "    mz.metrics.NormalizedDiscountedCumulativeGain(k=5),\n",
    "    mz.metrics.MeanAveragePrecision()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "text_left (InputLayer)          (None, 9645)         0                                            \n",
      "__________________________________________________________________________________________________\n",
      "text_right (InputLayer)         (None, 9645)         0                                            \n",
      "__________________________________________________________________________________________________\n",
      "dense_1 (Dense)                 (None, 300)          2893800     text_left[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dense_5 (Dense)                 (None, 300)          2893800     text_right[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "dense_2 (Dense)                 (None, 300)          90300       dense_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_6 (Dense)                 (None, 300)          90300       dense_5[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_3 (Dense)                 (None, 300)          90300       dense_2[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_7 (Dense)                 (None, 300)          90300       dense_6[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_4 (Dense)                 (None, 128)          38528       dense_3[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_8 (Dense)                 (None, 128)          38528       dense_7[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dot_1 (Dot)                     (None, 1)            0           dense_4[0][0]                    \n",
      "                                                                 dense_8[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_9 (Dense)                 (None, 1)            2           dot_1[0][0]                      \n",
      "==================================================================================================\n",
      "Total params: 6,225,858\n",
      "Trainable params: 6,225,858\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Configure, build and compile the DSSM model. All params are assigned before\n",
    "# build()/compile() are called.\n",
    "model = mz.models.DSSM()\n",
    "# Input shapes are taken from the fitted preprocessor's context.\n",
    "model.params['input_shapes'] = preprocessor.context['input_shapes']\n",
    "model.params['task'] = ranking_task\n",
    "# MLP settings; in the recorded summary each input branch shows three 300-unit\n",
    "# Dense layers followed by one 128-unit Dense layer.\n",
    "model.params['mlp_num_layers'] = 3\n",
    "model.params['mlp_num_units'] = 300\n",
    "model.params['mlp_num_fan_out'] = 128\n",
    "model.params['mlp_activation_func'] = 'relu'\n",
    "# Presumably fills any params not set above with defaults -- TODO confirm.\n",
    "model.guess_and_fill_missing_params()\n",
    "model.build()\n",
    "model.compile()\n",
    "# Print the layer-by-layer summary of the underlying backend model.\n",
    "model.backend.summary()\n",
    "\n",
    "# NOTE(review): this helper is not defined in this notebook (presumably it comes\n",
    "# from init.ipynb); going by its name it records the model params in a README.\n",
    "append_params_to_readme(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_x, pred_y = test_pack_processed[:].unpack()\n",
    "evaluate = mz.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: PairDataGenerator will be deprecated in MatchZoo v2.2. Use `DataGenerator` with callbacks instead.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "32"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_generator = mz.PairDataGenerator(train_pack_processed, num_dup=1, num_neg=4, batch_size=32, shuffle=True)\n",
    "len(train_generator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      "32/32 [==============================] - 7s 215ms/step - loss: 1.3325\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4431849853601904 - normalized_discounted_cumulative_gain@5(0.0): 0.5295386323998266 - mean_average_precision(0.0): 0.48303488812718776\n",
      "Epoch 2/20\n",
      "32/32 [==============================] - 6s 176ms/step - loss: 1.3159\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4353814849661657 - normalized_discounted_cumulative_gain@5(0.0): 0.5032525911610362 - mean_average_precision(0.0): 0.4776049822282439\n",
      "Epoch 3/20\n",
      "32/32 [==============================] - 5s 171ms/step - loss: 1.2955\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4088637099689691 - normalized_discounted_cumulative_gain@5(0.0): 0.48351010067595823 - mean_average_precision(0.0): 0.4432379861560312\n",
      "Epoch 4/20\n",
      "32/32 [==============================] - 6s 173ms/step - loss: 1.2726\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.46569627211992487 - normalized_discounted_cumulative_gain@5(0.0): 0.5305277638291452 - mean_average_precision(0.0): 0.4903964896023526\n",
      "Epoch 5/20\n",
      "32/32 [==============================] - 6s 172ms/step - loss: 1.2439\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.44778538209256513 - normalized_discounted_cumulative_gain@5(0.0): 0.5104380434420628 - mean_average_precision(0.0): 0.47615129143046664\n",
      "Epoch 6/20\n",
      "32/32 [==============================] - 6s 172ms/step - loss: 1.2202\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4452573045503587 - normalized_discounted_cumulative_gain@5(0.0): 0.5137975378931312 - mean_average_precision(0.0): 0.4742872412051932\n",
      "Epoch 7/20\n",
      "32/32 [==============================] - 5s 170ms/step - loss: 1.2038\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.41264292792428936 - normalized_discounted_cumulative_gain@5(0.0): 0.4740615140630128 - mean_average_precision(0.0): 0.45294026408574084\n",
      "Epoch 8/20\n",
      "32/32 [==============================] - 6s 172ms/step - loss: 1.1848\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.45527149721829696 - normalized_discounted_cumulative_gain@5(0.0): 0.5229678873030444 - mean_average_precision(0.0): 0.48490323375232625\n",
      "Epoch 9/20\n",
      "32/32 [==============================] - 5s 171ms/step - loss: 1.1504 3\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4401749964298954 - normalized_discounted_cumulative_gain@5(0.0): 0.5202410581724496 - mean_average_precision(0.0): 0.47967943778482564\n",
      "Epoch 10/20\n",
      "32/32 [==============================] - 5s 172ms/step - loss: 1.1314\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.44883790476151675 - normalized_discounted_cumulative_gain@5(0.0): 0.5215788412779597 - mean_average_precision(0.0): 0.48274548802838624\n",
      "Epoch 11/20\n",
      "32/32 [==============================] - 6s 173ms/step - loss: 1.1109\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.45835958548802597 - normalized_discounted_cumulative_gain@5(0.0): 0.5254562351939174 - mean_average_precision(0.0): 0.48819163523037407\n",
      "Epoch 12/20\n",
      "32/32 [==============================] - 6s 174ms/step - loss: 1.0915\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4540812972538116 - normalized_discounted_cumulative_gain@5(0.0): 0.502728792326375 - mean_average_precision(0.0): 0.48229166522394096\n",
      "Epoch 13/20\n",
      "32/32 [==============================] - 6s 173ms/step - loss: 1.0805\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4462255256302118 - normalized_discounted_cumulative_gain@5(0.0): 0.5097488218798687 - mean_average_precision(0.0): 0.4751972950775518\n",
      "Epoch 14/20\n",
      "32/32 [==============================] - 6s 174ms/step - loss: 1.0575\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4263585495923587 - normalized_discounted_cumulative_gain@5(0.0): 0.5014903707963352 - mean_average_precision(0.0): 0.46364289738480496\n",
      "Epoch 15/20\n",
      "32/32 [==============================] - 6s 179ms/step - loss: 1.0396\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.43936705108731194 - normalized_discounted_cumulative_gain@5(0.0): 0.5218713927469146 - mean_average_precision(0.0): 0.47233172236473137\n",
      "Epoch 16/20\n",
      "32/32 [==============================] - 6s 182ms/step - loss: 1.0156\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.45080782122574514 - normalized_discounted_cumulative_gain@5(0.0): 0.5181271382497495 - mean_average_precision(0.0): 0.4832342072703635\n",
      "Epoch 17/20\n",
      "32/32 [==============================] - 6s 175ms/step - loss: 0.9932\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.423108628561739 - normalized_discounted_cumulative_gain@5(0.0): 0.49596605935842625 - mean_average_precision(0.0): 0.4667294180948952\n",
      "Epoch 18/20\n",
      "32/32 [==============================] - 5s 172ms/step - loss: 0.9800\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4378084124127128 - normalized_discounted_cumulative_gain@5(0.0): 0.5098753091251295 - mean_average_precision(0.0): 0.4734416114488085\n",
      "Epoch 19/20\n",
      "32/32 [==============================] - 6s 172ms/step - loss: 0.9662\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4504450479915345 - normalized_discounted_cumulative_gain@5(0.0): 0.519107636100811 - mean_average_precision(0.0): 0.48712867088141415\n",
      "Epoch 20/20\n",
      "32/32 [==============================] - 6s 172ms/step - loss: 0.9512\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.45663442312293695 - normalized_discounted_cumulative_gain@5(0.0): 0.5363645153841258 - mean_average_precision(0.0): 0.4956098197015037\n"
     ]
    }
   ],
   "source": [
    "history = model.fit_generator(train_generator, epochs=20, callbacks=[evaluate], workers=5, use_multiprocessing=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
